# Copyright 2024 CHINA MERCHANTS BANK CO., LTD.
# Copyright 2024 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
transform huggingface model to mindspore safetensors.
"""

import os
import json
import argparse
import time
from pathlib import Path

import mindspore as ms
import torch

DTYPE_MS = {
    "int8": ms.int8,
    "uint8": ms.uint8,
    "float16": ms.float16,
    "bfloat16": ms.bfloat16,
    "float32": ms.float32,
    "fp32": ms.float32,
    "bf16": ms.bfloat16,
    "fp16": ms.float16,
}
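# Short aliases ('fp32', 'bf16', 'fp16') resolve to the same MindSpore dtypes as the full names.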


def name_replace(weight_name: str):
    """Map a HuggingFace weight name to its MindSpore equivalent.

    For example, 'model.layers.0.self_attn.q_proj.weight' is renamed to
    'model.layers.0.attention.wq.weight'.
    """
    weight_name = weight_name.replace('embed_tokens.', 'tok_embeddings.')
    weight_name = weight_name.replace('lm_head.', 'output.')
    weight_name = weight_name.replace('.self_attn.q_proj.', '.attention.wq.')
    weight_name = weight_name.replace('.self_attn.k_proj.', '.attention.wk.')
    weight_name = weight_name.replace('.self_attn.v_proj.', '.attention.wv.')
    weight_name = weight_name.replace('.self_attn.o_proj.', '.attention.wo.')
    weight_name = weight_name.replace('.self_attn.q_norm.', '.attention.q_norm.')
    weight_name = weight_name.replace('.self_attn.k_norm.', '.attention.k_norm.')
    weight_name = weight_name.replace('.mlp.gate_proj.', '.feed_forward.w1.')
    weight_name = weight_name.replace('.mlp.down_proj.', '.feed_forward.w2.')
    weight_name = weight_name.replace('.mlp.up_proj.', '.feed_forward.w3.')
    weight_name = weight_name.replace('.input_layernorm.', '.attention_norm.')
    weight_name = weight_name.replace(
        '.post_attention_layernorm.', '.ffn_norm.')
    # Required for lora weight conversion
    weight_name = weight_name.replace('base_model.model.', '')
    weight_name = weight_name.replace('lora_A.weight', 'lora_a')
    weight_name = weight_name.replace('lora_B.weight', 'lora_b')
    return weight_name


# pylint: disable=W0613
def convert_hf_to_ms(input_path, output_path, dtype=None, **kwargs):
    """convert hf weight to ms."""
    print(
        f"Trying to convert huggingface checkpoint in '{input_path}'.", flush=True)
    try:
        from safetensors.torch import load_file
    except ImportError as e:
        raise ImportError(
            "Failed to import 'safetensors'. "
            "Please make sure the 'safetensors' library is installed and available."
        ) from e

    try:
        ckpt_paths = sorted(Path(input_path).glob("*.safetensors"))
        dict_all = {}
        for ckpt_path in ckpt_paths:
            state_dict = load_file(ckpt_path, device='cpu')
            dict_all.update(state_dict)
        model_hf = dict(sorted(dict_all.items(), key=lambda x: x[0]))
    # pylint: disable=W0703
    except Exception as e:
        print(
            f"Could not load a HuggingFace checkpoint from '{input_path}': {e}", flush=True)
        return False
    ckpt_list = []
    for name, value in model_hf.items():
        name = name_replace(name)
        if name == 'model.norm.weight':
            name = 'model.norm_out.weight'
        if name == 'output.weight':
            name = 'lm_head.weight'
        if name == 'model.tok_embeddings.weight':
            name = 'model.tok_embeddings.embedding_weight'
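        # torch bfloat16 tensors cannot be converted to numpy directly, so cast to float32 first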
        value = value.detach().to(torch.float32).numpy()
        print(name, value.shape)
        ckpt_list.append({'name': name, 'data': ms.Tensor(value, dtype=dtype)})

    ms.save_checkpoint(ckpt_list, output_path, format="safetensors")
    print(f"\rConvert huggingface checkpoint finished, the mindspore checkpoint is saved in '{output_path}'.",
          flush=True)
    return True


def convert_lora_to_ms(input_path, output_path, dtype=None, **kwargs):
    """convert hf weight to ms."""
    print(
        f"Trying to convert huggingface checkpoint in '{input_path}'.", flush=True)
    try:
        from safetensors.torch import load_file
    except ImportError as e:
        raise ImportError(
            "Failed to import 'safetensors'. "
            "Please make sure the 'safetensors' library is installed and available."
        ) from e

    try:
        ckpt_paths = sorted(Path(input_path).glob("adapter_model.safetensors"))
        dict_all = {}
        for ckpt_path in ckpt_paths:
            state_dict = load_file(ckpt_path, device='cpu')
            dict_all.update(state_dict)
        model_hf = dict(sorted(dict_all.items(), key=lambda x: x[0]))
    # pylint: disable=W0703
    except Exception as e:
        print(
            f"Could not load a LoRA adapter checkpoint from '{input_path}': {e}", flush=True)
        return False
    ckpt_list = []
    for name, value in model_hf.items():
        name = name_replace(name)
        if name == 'model.norm.weight':
            name = 'model.norm_out.weight'
        if name == 'output.weight':
            name = 'lm_head.weight'
        if name == 'model.tok_embeddings.weight':
            name = 'model.tok_embeddings.embedding_weight'
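        # cast through float32 so bfloat16 adapter weights survive the numpy conversion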
        value = value.detach().to(torch.float32).numpy()
        print(name, value.shape)
        ckpt_list.append({'name': name, 'data': ms.Tensor(value, dtype=dtype)})

    ms.save_checkpoint(ckpt_list, output_path, format="safetensors")
    print(f"\rConvert huggingface checkpoint finished, the mindspore checkpoint is saved in '{output_path}'.",
          flush=True)
    convert_lora_config(input_path, kwargs.get('align_rank', False))
    return True


def convert_lora_config(input_path, align_rank):
    """modified config.json 'r' and 'target_modules' """
    config_path = os.path.join(input_path, "adapter_config.json")
    replace_rules = {
        'q_proj': 'wq',
        'k_proj': 'wk',
        'v_proj': 'wv',
        'o_proj': 'wo',
        'gate_proj': 'w1',
        'down_proj': 'w2',
        'up_proj': 'w3'
    }
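    # e.g. a 'q_proj' entry in target_modules is renamed to 'wq' to match the converted weights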
    try:
        with open(config_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        modified = False
        # r must be a multiple of 16 on Atlas 300V Pro
        if align_rank:
            if data["r"] % 16 != 0:
                data["r"] = (data["r"] // 16 + 1) * 16
                print("r modification successful, the configuration file has been updated!")
            else:
                print("The configuration r has already been modified, no need to modify it again.")

        for i, name in enumerate(data["target_modules"]):
            if name not in replace_rules:
                print(f"target_modules {name} does not need to be modified")
            else:
                data["target_modules"][i] = replace_rules[name]
                print(f"target_modules {name} has been modified to {replace_rules[name]}")
                modified = True

        if modified:
            with open(config_path, 'w', encoding='utf-8') as file:
                json.dump(data, file, indent=4)
            print("adapter_config.json updated successfully!")
        else:
            print("adapter_config.json already uses MindSpore names, nothing to update.")

    except FileNotFoundError:
        print(f"Error: File {config_path} does not exist")
    except KeyError:
        print("Error: The specified key does not exist in the JSON")
    except json.JSONDecodeError:
        print("Error: File content is not valid JSON format")


def convert_weight(para):
    """convert weight entrance"""
    if not hasattr(para, 'hf_safetensors_path'):
        para.hf_safetensors_path = para.input_path
    if not hasattr(para, 'ms_safetensors_path'):
        para.ms_safetensors_path = para.output_path
    if not para.dtype:
        para.dtype = "bf16"
    dtype = DTYPE_MS.get(para.dtype)
    if dtype is None:
        raise ValueError(f"Unsupported dtype '{para.dtype}'; expected one of {list(DTYPE_MS)}.")
    if para.is_lora:
        convert_lora_to_ms(input_path=para.hf_safetensors_path, output_path=para.ms_safetensors_path,
                           dtype=dtype, align_rank=para.align_rank)
    else:
        convert_hf_to_ms(input_path=para.hf_safetensors_path,
                         output_path=para.ms_safetensors_path, dtype=dtype)


if __name__ == "__main__":
    start = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument('--hf_safetensors_path', required=True,
                        help='Directory containing the HuggingFace *.safetensors files.')
    parser.add_argument('--ms_safetensors_path', default='transform.safetensors',
                        help='Output path for the MindSpore safetensors checkpoint.')
    parser.add_argument('--dtype', default='bf16', choices=sorted(DTYPE_MS),
                        help='Target dtype of the converted weights.')
    parser.add_argument('--is_lora', action='store_true',
                        help='Convert a LoRA adapter checkpoint instead of a full model.')
    parser.add_argument('--align_rank', action='store_true',
                        help='Round the LoRA rank r up to a multiple of 16.')
    args = parser.parse_args()

    convert_weight(args)
    end = time.time()
    print('time:', end - start, flush=True)
